# Shared setup for every bisecting k-means test in this file: label the
# context, open the test Spark connection, require dplyr, and load iris.
context("ml clustering - bisecting kmeans")

sc <- testthat_spark_connection()
test_requires("dplyr")
data("iris")

test_that("ml_bisecting_kmeans param setting", {
  test_requires_version("2.0.0", "bisecting kmeans support")

  # Explicit values for each tunable argument; the connection `sc` is passed
  # as `x`.
  param_values <- list(
    x = sc,
    k = 9,
    max_iter = 11,
    min_divisible_cluster_size = 3,
    seed = 98,
    features_col = "fcol",
    prediction_col = "pcol"
  )
  estimator <- do.call(ml_bisecting_kmeans, param_values)

  # Every supplied argument except `x` should be reported back unchanged
  # by ml_params().
  param_names <- names(param_values)[names(param_values) != "x"]

  expect_equal(
    ml_params(estimator, param_names),
    param_values[param_names]
  )
})

test_that("ml_bisecting_kmeans() default params are correct", {
  test_requires_version("2.0.0", "bisecting kmeans support")

  # Add a bisecting k-means stage to an empty pipeline, then pull that
  # first stage back out.
  estimator <- ml_stage(ml_bisecting_kmeans(ml_pipeline(sc)), 1)

  # Presumably the R wrapper's formal defaults, skipping `x`, `uid`, `...`
  # and `seed` (NOTE(review): seed likely has no fixed default — confirm).
  expected_defaults <- get_default_args(
    ml_bisecting_kmeans,
    c("x", "uid", "...", "seed")
  )

  actual_params <- ml_params(estimator, names(expected_defaults))
  expect_equal(actual_params, expected_defaults)
})

# Fits a 2-cluster bisecting k-means model on the sample libsvm data, then
# checks the computed cost and both cluster centers against hard-coded
# expected values.
test_that("ml_bisecting_kmeans() works properly", {
  test_requires_version("2.0.0", "bisecting kmeans support")
  # Search recursively under the working directory for the sample data file.
  # NOTE(review): `pattern` is a regular expression (unanchored, with an
  # unescaped `.`) and dir() returns every match, so if more than one copy
  # exists `sample_data_path` has length > 1 — consider anchoring the
  # pattern and taking the first match.
  sample_data_path <- dir(getwd(), recursive = TRUE, pattern = "sample_libsvm_data.txt", full.names = TRUE)

  sample_data <- spark_read_libsvm(sc, "sample_data",
                                   sample_data_path, overwrite = TRUE)
  # Fixed seed, presumably so the clustering (and the pinned values below)
  # are reproducible across runs.
  bkm <- ml_bisecting_kmeans(sample_data, k = 2, seed = 1)
  # expect_equal() compares numbers with a tolerance, so tiny floating-point
  # drift in the cost does not fail the test.
  expect_equal(bkm$compute_cost(sample_data), 214807298)

  # Expected centers: a list of two numeric vectors, one per cluster (k = 2),
  # compared against `bkm$cluster_centers` in the order listed here.
  # NOTE(review): these look like values recorded from a reference run;
  # regenerate them if the sample data or Spark's implementation changes.
  cluster_centers <- list(c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.08474576271186,
                            3.23728813559322, 4, 5.44067796610169, 5.6271186440678, 1.30508474576271,
                            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.44067796610169,
                            4.28813559322034, 2.23728813559322, 0.152542372881356, 6.3728813559322,
                            27.0508474576271, 38.1864406779661, 44.2203389830508, 45.4406779661017,
                            49.9830508474576, 67.9491525423729, 53.8474576271186, 34.9830508474576,
                            10.0338983050847, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.38983050847458,
                            4.08474576271186, 4.25423728813559, 4.25423728813559, 3.32203389830508,
                            10.7796610169492, 45.6271186440678, 64.3050847457627, 69.8305084745763,
                            81.0508474576271, 89.864406779661, 103.593220338983, 80.0677966101695,
                            64.3728813559322, 22.8983050847458, 3.83050847457627, 0, 0, 0,
                            0, 0, 0, 0, 0, 0, 0, 0, 0, 2.96610169491525, 4.25423728813559,
                            4.25423728813559, 4.25423728813559, 5.38983050847458, 18.864406779661,
                            46.8135593220339, 66.728813559322, 85.7627118644068, 95.8305084745763,
                            117, 113.915254237288, 94.6440677966102, 66.8813559322034, 21.7118644067797,
                            6.38983050847458, 0.135593220338983, 0, 0, 0, 0, 0, 0, 0, 0,
                            0, 0, 0, 0.220338983050847, 3.45762711864407, 4.25423728813559,
                            5.49152542372881, 8.25423728813559, 21.1864406779661, 48.864406779661,
                            68.2033898305085, 99.0169491525424, 109.135593220339, 132.728813559322,
                            113.64406779661, 96.9830508474576, 61.5084745762712, 20.1694915254237,
                            7.1864406779661, 1.15254237288136, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            0, 0, 0, 0.661016949152542, 4.25423728813559, 7.33898305084746,
                            10.271186440678, 21.9661016949153, 47.5593220338983, 74.2203389830509,
                            117.033898305085, 140.932203389831, 142.64406779661, 117.305084745763,
                            93.1864406779661, 45.4237288135593, 9.94915254237288, 3.49152542372881,
                            1.69491525423729, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.254237288135593,
                            2.89830508474576, 9.71186440677966, 19, 25.3220338983051, 47.0677966101695,
                            79.4237288135593, 139.745762711864, 159.423728813559, 150.779661016949,
                            109.694915254237, 76.4745762711864, 29.8135593220339, 7.16949152542373,
                            2.77966101694915, 1.15254237288136, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            0, 0, 0, 0.355932203389831, 5.49152542372881, 11.3220338983051,
                            14.8474576271186, 18.4576271186441, 42, 90.271186440678, 167.728813559322,
                            184.661016949153, 148.64406779661, 95.0847457627119, 45.9661016949153,
                            15.7627118644068, 1.72881355932203, 3.89830508474576, 4.30508474576271,
                            3.15254237288136, 0.23728813559322, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                            1.49152542372881, 7.27118644067797, 10.0677966101695, 8.91525423728813,
                            14.8474576271186, 16.0169491525424, 34.2881355932203, 107.779661016949,
                            192.728813559322, 199.152542372881, 137.423728813559, 82, 28.8135593220339,
                            6.64406779661017, 0, 0.559322033898305, 3.03389830508475, 4.28813559322034,
                            2.3728813559322, 0, 0, 0, 0, 0, 0, 0, 0, 0.796610169491525, 7.1864406779661,
                            11.3220338983051, 8.06779661016949, 5.91525423728814, 10.271186440678,
                            10.6949152542373, 35.4745762711864, 126.576271186441, 215.372881355932,
                            204.949152542373, 131.237288135593, 61.7966101694915, 12.9830508474576,
                            4.28813559322034, 0, 0, 0.406779661016949, 3.54237288135593,
                            3.77966101694915, 0, 0, 0, 0, 0, 0, 0, 0, 3.86440677966102, 8.54237288135593,
                            4.23728813559322, 0.983050847457627, 1.13559322033898, 4.91525423728814,
                            5, 40.5762711864407, 173.440677966102, 231.186440677966, 206.271186440678,
                            115.338983050847, 34.5084745762712, 6.22033898305085, 3.91525423728814,
                            0, 0, 0, 2.30508474576271, 4.28813559322034, 0, 0, 0, 0, 0, 0,
                            0, 0, 6.13559322033898, 4.32203389830508, 0.254237288135593,
                            0, 1.66101694915254, 4.30508474576271, 4.30508474576271, 60.3050847457627,
                            201.186440677966, 234.28813559322, 197.35593220339, 79.1525423728814,
                            22.2372881355932, 2.8135593220339, 2.6271186440678, 0, 0, 0,
                            2.30508474576271, 3.64406779661017, 0, 0, 0, 0, 0, 0, 0, 0, 4.32203389830508,
                            1.40677966101695, 0.932203389830508, 0, 1.66101694915254, 4.69491525423729,
                            12.8983050847458, 100.898305084746, 217.796610169492, 234.033898305085,
                            165.915254237288, 48.5762711864407, 10.7796610169492, 4.03389830508475,
                            2.6271186440678, 0, 0.0847457627118644, 2.15254237288136, 4.16949152542373,
                            2.25423728813559, 0, 0, 0, 0, 0, 0, 0, 0, 4.55932203389831, 2.47457627118644,
                            0.389830508474576, 0, 1.66101694915254, 9.79661016949153, 41.6271186440678,
                            124.406779661017, 214.203389830508, 223.237288135593, 134.322033898305,
                            36.5593220338983, 5.96610169491525, 7.71186440677966, 5.47457627118644,
                            0.288135593220339, 2.54237288135593, 4.1864406779661, 1.54237288135593,
                            0.152542372881356, 0, 0, 0, 0, 0, 0, 0, 0, 5.27118644067797,
                            4.20338983050847, 0, 0, 2.45762711864407, 25, 75.593220338983,
                            136.271186440678, 211.983050847458, 196.949152542373, 105.457627118644,
                            31.9322033898305, 7.22033898305085, 11.5593220338983, 10.6779661016949,
                            5.76271186440678, 3.66101694915254, 1.42372881355932, 0, 0, 0,
                            0, 0, 0, 0, 0, 0, 0, 1.88135593220339, 4.25423728813559, 1.47457627118644,
                            0, 6.49152542372881, 49.5593220338983, 102.881355932203, 149.152542372881,
                            202.203389830508, 166.71186440678, 89.4237288135593, 33.6101694915254,
                            16.8813559322034, 17.0338983050847, 13.5762711864407, 8.03389830508475,
                            1.74576271186441, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.23728813559322,
                            3.13559322033898, 4.08474576271186, 6.30508474576271, 25.2203389830508,
                            79.7118644067797, 116.627118644068, 153.237288135593, 173.322033898305,
                            148.830508474576, 85.6271186440678, 46.6779661016949, 25.5762711864407,
                            17.4745762711864, 14.1016949152542, 9.33898305084746, 2.33898305084746,
                            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.389830508474576, 4.44067796610169,
                            18.8813559322034, 50.1016949152542, 92.7796610169491, 132.847457627119,
                            149, 144.694915254237, 126.338983050847, 76.2372881355932, 47.3050847457627,
                            30.0169491525424, 18.9322033898305, 8.30508474576271, 6.69491525423729,
                            3.23728813559322, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4.23728813559322,
                            11.2372881355932, 25.9661016949153, 68, 99.7118644067797, 127.864406779661,
                            127.728813559322, 123.864406779661, 107.372881355932, 77.728813559322,
                            47.5932203389831, 26.9830508474576, 14.8305084745763, 7.76271186440678,
                            4.93220338983051, 2.93220338983051, 0, 0, 0, 0, 2.3728813559322,
                            0, 0, 0, 0, 0, 0, 0, 3.45762711864407, 13.3220338983051, 37.7627118644068,
                            81.5593220338983, 105.813559322034, 119.322033898305, 105.033898305085,
                            109.237288135593, 98.9152542372881, 68.4237288135593, 36.8813559322034,
                            24.5423728813559, 17.6949152542373, 14.6271186440678, 5.94915254237288,
                            2.93220338983051, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.203389830508475,
                            5.89830508474576, 26.4237288135593, 75.2542372881356, 91.1186440677966,
                            84.4576271186441, 60.9491525423729, 63.3728813559322, 63.5593220338983,
                            40.3898305084746, 20.4406779661017, 18.3728813559322, 14.5932203389831,
                            7.20338983050847, 4.59322033898305, 2.93220338983051, 0, 0, 0,
                            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.949152542372881, 20.4915254237288,
                            26.4067796610169, 19.5084745762712, 9.49152542372881, 4.98305084745763,
                            14.6949152542373, 10.5762711864407, 0.525423728813559, 0, 0,
                            0), c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.36585365853659, 6.02439024390244,
                                  2.95121951219512, 1.70731707317073, 6.21951219512195, 4.02439024390244,
                                  2.78048780487805, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                  0, 0, 0, 0, 1.53658536585366, 5.92682926829268, 13.0487804878049,
                                  43.6829268292683, 65.1951219512195, 76.5609756097561, 91.1463414634146,
                                  97.2926829268293, 92.7560975609756, 65.4390243902439, 25.4146341463415,
                                  6.80487804878049, 1.92682926829268, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                                  0, 0, 0, 0, 0, 1.73170731707317, 8.41463414634146, 25.7317073170732,
                                  55.8780487804878, 88, 123.219512195122, 166.439024390244, 195.512195121951,
                                  192.317073170732, 173.390243902439, 142.365853658537, 100.536585365854,
                                  48.0731707317073, 18.0975609756098, 3.46341463414634, 2.07317073170732,
                                  0.341463414634146, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4.21951219512195,
                                  13.6829268292683, 34.0731707317073, 69.5609756097561, 87.0731707317073,
                                  129.951219512195, 171.707317073171, 205.19512195122, 224.024390243902,
                                  219.268292682927, 199.926829268293, 188.707317073171, 152.853658536585,
                                  96.1219512195122, 43.219512195122, 13.7073170731707, 6.09756097560976,
                                  3.5609756097561, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3.19512195121951,
                                  16.0731707317073, 36.2926829268293, 62.1951219512195, 87.390243902439,
                                  125.878048780488, 167.09756097561, 211.80487804878, 211.268292682927,
                                  202.146341463415, 211.90243902439, 223.439024390244, 208.585365853659,
                                  205.585365853659, 159.317073170732, 78.9024390243902, 21.3170731707317,
                                  7.02439024390244, 6.09756097560976, 0, 0, 0, 0, 0, 0, 0, 0, 0.560975609756098,
                                  11.7317073170732, 28.9268292682927, 60.9512195121951, 83.1951219512195,
                                  129.09756097561, 167, 187.121951219512, 198.682926829268, 171.609756097561,
                                  147, 149.19512195122, 183.853658536585, 187.658536585366, 209.487804878049,
                                  189.756097560976, 122.365853658537, 48.2439024390244, 15.6585365853659,
                                  6.41463414634146, 0, 0, 0, 0, 0, 0, 0, 0, 4.70731707317073, 23.3414634146341,
                                  42.6585365853659, 77.8048780487805, 111.878048780488, 150.634146341463,
                                  190.073170731707, 173.536585365854, 156.073170731707, 125.414634146341,
                                  93.3170731707317, 91.0975609756098, 110.512195121951, 134.292682926829,
                                  179.658536585366, 196.390243902439, 148.853658536585, 89.5609756097561,
                                  26.8780487804878, 9.95121951219512, 0, 0, 0, 0, 0, 0, 0, 0, 7.8780487804878,
                                  34.8048780487805, 63.6585365853659, 96.2682926829268, 142.634146341463,
                                  167.634146341463, 165.024390243902, 150.121951219512, 120, 89.2682926829268,
                                  55.4634146341463, 50.8292682926829, 66.8780487804878, 91.9268292682927,
                                  142.365853658537, 191.048780487805, 161.80487804878, 112.926829268293,
                                  48, 14.9024390243902, 0, 0, 0, 0, 0, 0, 0, 0, 20.8536585365854,
                                  55.780487804878, 84.780487804878, 114.90243902439, 158.317073170732,
                                  155.487804878049, 147.121951219512, 122.439024390244, 80.8780487804878,
                                  47.6341463414634, 26.6341463414634, 25.2439024390244, 37.1463414634146,
                                  64.6585365853659, 104.414634146341, 162.219512195122, 160.146341463415,
                                  133.121951219512, 72.6585365853659, 23.7317073170732, 0, 0, 0,
                                  0, 0, 0, 0, 0, 33.390243902439, 72.4390243902439, 101.19512195122,
                                  144.756097560976, 177.951219512195, 150.146341463415, 120.024390243902,
                                  84.5121951219512, 50.4878048780488, 22.7073170731707, 11.609756097561,
                                  11.8780487804878, 26.9756097560976, 53.9512195121951, 81.5121951219512,
                                  149.878048780488, 158.121951219512, 143.048780487805, 91.4878048780488,
                                  32.6585365853659, 0, 0, 0, 0, 0, 0, 0, 0, 36.8780487804878, 89.2439024390244,
                                  126.585365853659, 161.829268292683, 173.146341463415, 130, 91.0487804878049,
                                  61.8780487804878, 22.8780487804878, 8.78048780487805, 6.68292682926829,
                                  3.34146341463415, 16.4390243902439, 56.5365853658537, 79.1219512195122,
                                  143.512195121951, 162.243902439024, 152.09756097561, 93.7560975609756,
                                  37.3658536585366, 0, 0, 0, 0, 0, 0, 0, 0.463414634146341, 46.3170731707317,
                                  107.829268292683, 147.073170731707, 174.024390243902, 152.048780487805,
                                  98.1707317073171, 63.5121951219512, 25.5853658536585, 6.73170731707317,
                                  2.24390243902439, 1.73170731707317, 0.24390243902439, 16.0975609756098,
                                  46.1951219512195, 89.1463414634146, 139.317073170732, 157.878048780488,
                                  135.536585365854, 93.5365853658537, 41, 0, 0, 0, 0, 0, 0, 0,
                                  0.829268292682927, 66.3658536585366, 129.487804878049, 162.19512195122,
                                  184.878048780488, 145.048780487805, 82.9756097560976, 30.1951219512195,
                                  7, 0, 0, 0, 0, 14.609756097561, 45.9512195121951, 97.5121951219512,
                                  147.585365853659, 147.853658536585, 123.487804878049, 87.219512195122,
                                  36.5121951219512, 0, 0, 0, 0, 0, 0, 0, 1.29268292682927, 80.4634146341463,
                                  139.951219512195, 179.365853658537, 183.024390243902, 129.317073170732,
                                  66.6341463414634, 21.219512195122, 6.34146341463415, 0, 0, 0,
                                  2.04878048780488, 13.8048780487805, 62.390243902439, 114.682926829268,
                                  150.292682926829, 149.292682926829, 100.512195121951, 76.3658536585366,
                                  33.7317073170732, 0, 0, 0, 0, 0, 0, 0, 6.65853658536585, 84.1219512195122,
                                  144.292682926829, 186.975609756098, 178.90243902439, 130.317073170732,
                                  55.780487804878, 19.0487804878049, 9.09756097560976, 0.268292682926829,
                                  0.75609756097561, 2.46341463414634, 13.7073170731707, 33.9756097560976,
                                  90.7073170731707, 137.756097560976, 161.414634146341, 130.951219512195,
                                  87.219512195122, 67.0243902439024, 26.8780487804878, 0, 0, 0,
                                  0, 0, 0, 0, 6.73170731707317, 74.9756097560976, 135.243902439024,
                                  193.878048780488, 188.780487804878, 132.170731707317, 54.6829268292683,
                                  17.3658536585366, 9.63414634146342, 5.14634146341463, 8.17073170731707,
                                  15.8292682926829, 36.7073170731707, 88.0243902439024, 132.414634146341,
                                  156.560975609756, 145.780487804878, 99.3414634146341, 73.9268292682927,
                                  48.2439024390244, 16.1463414634146, 0, 0, 0, 0, 0, 0, 0, 4.46341463414634,
                                  56.2682926829268, 123.780487804878, 186.390243902439, 207.365853658537,
                                  163.292682926829, 95.609756097561, 50.5609756097561, 26.2439024390244,
                                  23.0243902439024, 34.0243902439024, 51.5365853658537, 89.390243902439,
                                  139.048780487805, 161.073170731707, 160.780487804878, 117.512195121951,
                                  90.5609756097561, 57.8048780487805, 33.5853658536585, 6.17073170731707,
                                  0, 0, 0, 0, 0, 0, 0, 2.48780487804878, 31.7560975609756, 98.7073170731707,
                                  160.585365853659, 210.121951219512, 197.317073170732, 156.073170731707,
                                  120.463414634146, 105.731707317073, 113.414634146341, 115.780487804878,
                                  135.926829268293, 155.756097560976, 175.707317073171, 162.365853658537,
                                  125.048780487805, 100.09756097561, 66.780487804878, 43.3658536585366,
                                  14.8780487804878, 0, 0, 0, 0, 0, 0, 0, 0, 1.78048780487805, 13.4634146341463,
                                  57.1463414634146, 132.463414634146, 191.414634146341, 224.756097560976,
                                  217.487804878049, 193.512195121951, 187.634146341463, 181.292682926829,
                                  185.975609756098, 196.439024390244, 197.439024390244, 177.170731707317,
                                  141.878048780488, 95.1707317073171, 70.4146341463415, 46, 21.1463414634146,
                                  1.65853658536585, 0, 0, 0, 0, 0, 0, 0, 0, 1.51219512195122, 7.75609756097561,
                                  22.4878048780488, 68.8292682926829, 135.829268292683, 191.829268292683,
                                  226.536585365854, 232, 222.975609756098, 223.829268292683, 216.707317073171,
                                  201.90243902439, 175.170731707317, 130.731707317073, 89.0975609756098,
                                  65.6585365853659, 42.2682926829268, 18.2926829268293, 4.07317073170732,
                                  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2.02439024390244, 8.68292682926829,
                                  22, 52.4146341463415, 106.780487804878, 154.756097560976, 183.829268292683,
                                  190.268292682927, 175.829268292683, 163.707317073171, 117.682926829268,
                                  83.2926829268293, 64.7560975609756, 43.2682926829268, 27.4634146341463,
                                  11.0243902439024, 2.34146341463415, 0.24390243902439, 0, 0, 0,
                                  0, 0, 0, 0, 0, 0, 0, 0, 1.8780487804878, 11.9512195121951, 12.6341463414634,
                                  18.3658536585366, 26.3414634146341, 43.4878048780488, 49.9512195121951,
                                  49.9268292682927, 47.9268292682927, 33.8048780487805, 20.1463414634146,
                                  16, 8.90243902439024, 2.65853658536585))

  expect_equal(bkm$cluster_centers, cluster_centers)
})


# Snapshot-style check: printing a fitted model (iris, every column except
# Species as features, k = 5, seed = 11) must match the stored reference
# output file. The description is distinct from the earlier
# "works properly" test so failures are reported unambiguously.
test_that("ml_bisecting_kmeans() print output matches reference file", {
  test_requires_version("2.0.0", "bisecting kmeans support")
  iris_tbl <- testthat_tbl("iris")
  expect_output_file(
    print(ml_bisecting_kmeans(iris_tbl, ~ . - Species, k = 5, seed = 11)),
    output_file("print/bisecting-kmeans.txt")
  )
})
